1   /*                        __    __  __  __    __  ___
2    *                       \  \  /  /    \  \  /  /  __/
3    *                        \  \/  /  /\  \  \/  /  /
4    *                         \____/__/  \__\____/__/.ɪᴏ
5    * ᶜᵒᵖʸʳᶦᵍʰᵗ ᵇʸ ᵛᵃᵛʳ ⁻ ˡᶦᶜᵉⁿˢᵉᵈ ᵘⁿᵈᵉʳ ᵗʰᵉ ᵃᵖᵃᶜʰᵉ ˡᶦᶜᵉⁿˢᵉ ᵛᵉʳˢᶦᵒⁿ ᵗʷᵒ ᵈᵒᵗ ᶻᵉʳᵒ
6    */
7   package io.vavr.collection;
8   
9   import io.vavr.collection.JavaConverters.ChangePolicy;
10  import io.vavr.collection.JavaConverters.ListView;
11  import io.vavr.control.Option;
12  
13  import java.util.ArrayList;
14  import java.util.Collection;
15  import java.util.Objects;
16  import java.util.function.*;
17  
18  /**
19   * Internal class, containing helpers.
20   *
21   * @author Daniel Dietrich
22   */
23  final class Collections {
24  
25      // checks, if the *elements* of the given iterables are equal
26      static boolean areEqual(Iterable<?> iterable1, Iterable<?> iterable2) {
27          final java.util.Iterator<?> iter1 = iterable1.iterator();
28          final java.util.Iterator<?> iter2 = iterable2.iterator();
29          while (iter1.hasNext() && iter2.hasNext()) {
30              if (!Objects.equals(iter1.next(), iter2.next())) {
31                  return false;
32              }
33          }
34          return iter1.hasNext() == iter2.hasNext();
35      }
36  
37      static <T, C extends Seq<T>> C asJava(C source, Consumer<? super java.util.List<T>> action, ChangePolicy changePolicy) {
38          Objects.requireNonNull(action, "action is null");
39          final ListView<T, C> view = JavaConverters.asJava(source, changePolicy);
40          action.accept(view);
41          return view.getDelegate();
42      }
43  
44      @SuppressWarnings("unchecked")
45      static <T, S extends Seq<T>> Iterator<S> crossProduct(S empty, S seq, int power) {
46          if (power < 0) {
47              return Iterator.empty();
48          } else {
49              return Iterator.range(0, power)
50                      .foldLeft(Iterator.of(empty), (product, ignored) -> product.flatMap(el -> seq.map(t -> (S) el.append(t))));
51          }
52      }
53      
54      @SuppressWarnings("unchecked")
55      static <T, S extends IndexedSeq<T>> S dropRightUntil(S seq, Predicate<? super T> predicate) {
56          Objects.requireNonNull(predicate, "predicate is null");
57          for (int i = seq.length() - 1; i >= 0; i--) {
58              if (predicate.test(seq.get(i))) {
59                  return (S) seq.take(i + 1);
60              }
61          }
62          return (S) seq.take(0);
63      }
64  
65      @SuppressWarnings("unchecked")
66      static <T, S extends Seq<T>> S dropUntil(S seq, Predicate<? super T> predicate) {
67          Objects.requireNonNull(predicate, "predicate is null");
68          for (int i = 0; i < seq.length(); i++) {
69              if (predicate.test(seq.get(i))) {
70                  return (S) seq.drop(i);
71              }
72          }
73          return (S) seq.take(0);
74      }
75  
76      @SuppressWarnings("unchecked")
77      static <K, V> boolean equals(Map<K, V> source, Object object) {
78          if (source == object) {
79              return true;
80          } else if (source != null && object instanceof Map) {
81              final Map<K, V> map = (Map<K, V>) object;
82              if (source.size() != map.size()) {
83                  return false;
84              } else {
85                  try {
86                      return source.forAll(map::contains);
87                  } catch (ClassCastException e) {
88                      return false;
89                  }
90              }
91          } else {
92              return false;
93          }
94      }
95  
96      @SuppressWarnings("unchecked")
97      static <K, V> boolean equals(Multimap<K, V> source, Object object) {
98          if (source == object) {
99              return true;
100         } else if (source != null && object instanceof Multimap) {
101             final Multimap<K, V> multimap = (Multimap<K, V>) object;
102             if (source.size() != multimap.size()) {
103                 return false;
104             } else {
105                 try {
106                     return source.forAll(multimap::contains);
107                 } catch (ClassCastException e) {
108                     return false;
109                 }
110             }
111         } else {
112             return false;
113         }
114     }
115 
116     @SuppressWarnings("unchecked")
117     static <V> boolean equals(Seq<V> source, Object object) {
118         if (object == source) {
119             return true;
120         } else if (source != null && object instanceof Seq) {
121             final Seq<V> seq = (Seq<V>) object;
122             return seq.size() == source.size() && areEqual(source, seq);
123         } else {
124             return false;
125         }
126     }
127 
128     @SuppressWarnings("unchecked")
129     static <V> boolean equals(Set<V> source, Object object) {
130         if (source == object) {
131             return true;
132         } else if (source != null && object instanceof Set) {
133             final Set<V> set = (Set<V>) object;
134             if (source.size() != set.size()) {
135                 return false;
136             } else {
137                 try {
138                     return source.forAll(set::contains);
139                 } catch (ClassCastException e) {
140                     return false;
141                 }
142             }
143         } else {
144             return false;
145         }
146     }
147 
148     static <T> Iterator<T> fill(int n, Supplier<? extends T> supplier) {
149         Objects.requireNonNull(supplier, "supplier is null");
150         return tabulate(n, ignored -> supplier.get());
151     }
152 
153     static <C extends Traversable<T>, T> C fill(int n, Supplier<? extends T> s, C empty, Function<T[], C> of) {
154         Objects.requireNonNull(s, "s is null");
155         Objects.requireNonNull(empty, "empty is null");
156         Objects.requireNonNull(of, "of is null");
157         return tabulate(n, anything -> s.get(), empty, of);
158     }
159 
160     static <T, C, R extends Iterable<T>> Map<C, R> groupBy(Traversable<T> source, Function<? super T, ? extends C> classifier, Function<? super Iterable<T>, R> mapper) {
161         Objects.requireNonNull(classifier, "classifier is null");
162         Objects.requireNonNull(mapper, "mapper is null");
163         Map<C, R> results = LinkedHashMap.empty();
164         for (java.util.Map.Entry<? extends C, Collection<T>> entry : groupBy(source, classifier)) {
165             results = results.put(entry.getKey(), mapper.apply(entry.getValue()));
166         }
167         return results;
168 
169     }
170 
171     private static <T, C> java.util.Set<java.util.Map.Entry<C, Collection<T>>> groupBy(Traversable<T> source, Function<? super T, ? extends C> classifier) {
172         final java.util.Map<C, Collection<T>> results = new java.util.LinkedHashMap<>(source.isTraversableAgain() ? source.size() : 16);
173         for (T value : source) {
174             final C key = classifier.apply(value);
175             results.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
176         }
177         return results.entrySet();
178     }
179 
180     // hashes the elements respecting their order
181     static int hashOrdered(Iterable<?> iterable) {
182         return hash(iterable, (acc, hash) -> acc * 31 + hash);
183     }
184 
185     // hashes the elements regardless of their order
186     static int hashUnordered(Iterable<?> iterable) {
187         return hash(iterable, (acc, hash) -> acc + hash);
188     }
189 
190     private static int hash(Iterable<?> iterable, IntBinaryOperator accumulator) {
191         if (iterable == null) {
192             return 0;
193         } else {
194             int hashCode = 1;
195             for (Object o : iterable) {
196                 hashCode = accumulator.applyAsInt(hashCode, Objects.hashCode(o));
197             }
198             return hashCode;
199         }
200     }
201 
202     static Option<Integer> indexOption(int index) {
203         return Option.when(index >= 0, index);
204     }
205 
206     // @param iterable may not be null
207     static boolean isEmpty(Iterable<?> iterable) {
208         return iterable instanceof Traversable && ((Traversable) iterable).isEmpty()
209                 || iterable instanceof Collection && ((Collection) iterable).isEmpty()
210                 || !iterable.iterator().hasNext();
211     }
212 
213     static <T> boolean isTraversableAgain(Iterable<? extends T> iterable) {
214         return (iterable instanceof Collection) ||
215                 (iterable instanceof Traversable && ((Traversable<?>) iterable).isTraversableAgain());
216     }
217 
218     @SuppressWarnings("unchecked")
219     static <K, V, K2, U extends Map<K2, V>> U mapKeys(Map<K, V> source, U zero, Function<? super K, ? extends K2> keyMapper, BiFunction<? super V, ? super V, ? extends V> valueMerge) {
220         Objects.requireNonNull(zero, "zero is null");
221         Objects.requireNonNull(keyMapper, "keyMapper is null");
222         Objects.requireNonNull(valueMerge, "valueMerge is null");
223         return source.foldLeft(zero, (acc, entry) -> {
224             final K2 k2 = keyMapper.apply(entry._1);
225             final V v2 = entry._2;
226             final Option<V> v1 = acc.get(k2);
227             final V v = v1.isDefined() ? valueMerge.apply(v1.get(), v2) : v2;
228             return (U) acc.put(k2, v);
229         });
230     }
231 
232     @SuppressWarnings("unchecked")
233     static <C extends Traversable<T>, T> C removeAll(C source, Iterable<? extends T> elements) {
234         Objects.requireNonNull(elements, "elements is null");
235         if (source.isEmpty()) {
236             return source;
237         } else {
238             final Set<T> removed = HashSet.ofAll(elements);
239             return removed.isEmpty() ? source : (C) source.filter(e -> !removed.contains(e));
240         }
241     }
242 
243     @SuppressWarnings("unchecked")
244     static <C extends Traversable<T>, T> C removeAll(C source, Predicate<? super T> predicate) {
245         Objects.requireNonNull(predicate, "predicate is null");
246         if (source.isEmpty()) {
247             return source;
248         } else {
249             return (C) source.filter(predicate.negate());
250         }
251     }
252 
253     @SuppressWarnings("unchecked")
254     static <C extends Traversable<T>, T> C removeAll(C source, T element) {
255         if (source.isEmpty()) {
256             return source;
257         } else {
258             return (C) source.filter(e -> !Objects.equals(e, element));
259         }
260     }
261 
262     @SuppressWarnings("unchecked")
263     static <C extends Traversable<T>, T> C retainAll(C source, Iterable<? extends T> elements) {
264         Objects.requireNonNull(elements, "elements is null");
265         if (source.isEmpty()) {
266             return source;
267         } else {
268             final Set<T> retained = HashSet.ofAll(elements);
269             return (C) source.filter(retained::contains);
270         }
271     }
272 
273     static <T> Iterator<T> reverseIterator(Iterable<T> iterable) {
274         if (iterable instanceof java.util.List) {
275             return reverseListIterator((java.util.List<T>) iterable);
276         } else if (iterable instanceof Seq) {
277             return ((Seq<T>) iterable).reverseIterator();
278         } else {
279             return List.<T>empty().pushAll(iterable).iterator();
280         }
281     }
282 
283     private static <T> Iterator<T> reverseListIterator(java.util.List<T> list) {
284         return new Iterator<T>() {
285             private final java.util.ListIterator<T> delegate = list.listIterator(list.size());
286 
287             @Override
288             public boolean hasNext() {
289                 return delegate.hasPrevious();
290             }
291 
292             @Override
293             public T next() {
294                 return delegate.previous();
295             }
296         };
297     }
298 
299     static <T, U, R extends Traversable<U>> R scanLeft(Traversable<? extends T> source,
300                                                        U zero, BiFunction<? super U, ? super T, ? extends U> operation, Function<Iterator<U>, R> finisher) {
301         Objects.requireNonNull(operation, "operation is null");
302         final Iterator<U> iterator = source.iterator().scanLeft(zero, operation);
303         return finisher.apply(iterator);
304     }
305 
306     static <T, U, R extends Traversable<U>> R scanRight(Traversable<? extends T> source,
307                                                         U zero, BiFunction<? super T, ? super U, ? extends U> operation, Function<Iterator<U>, R> finisher) {
308         Objects.requireNonNull(operation, "operation is null");
309         final Iterator<? extends T> reversedElements = reverseIterator(source);
310         return scanLeft(reversedElements, zero, (u, t) -> operation.apply(t, u), us -> finisher.apply(reverseIterator(us)));
311     }
312 
313     static <T, S extends Seq<T>> S shuffle(S source, Function<? super Iterable<T>, S> ofAll) {
314         if (source.length() <= 1) {
315             return source;
316         }
317 
318         final java.util.List<T> list = source.toJavaList();
319         java.util.Collections.shuffle(list);
320         return ofAll.apply(list);
321     }
322 
323     static void subSequenceRangeCheck(int beginIndex, int endIndex, int length) {
324         if (beginIndex < 0 || endIndex > length) {
325             throw new IndexOutOfBoundsException("subSequence(" + beginIndex + ", " + endIndex + "), length = " + length);
326         } else if (beginIndex > endIndex) {
327             throw new IllegalArgumentException("subSequence(" + beginIndex + ", " + endIndex + ")");
328         }
329     }
330 
331     static <T> Iterator<T> tabulate(int n, Function<? super Integer, ? extends T> f) {
332         Objects.requireNonNull(f, "f is null");
333         if (n <= 0) {
334             return Iterator.empty();
335         } else {
336             return new AbstractIterator<T>() {
337 
338                 int i = 0;
339 
340                 @Override
341                 public boolean hasNext() {
342                     return i < n;
343                 }
344 
345                 @Override
346                 protected T getNext() {
347                     return f.apply(i++);
348                 }
349             };
350         }
351     }
352 
353     static <C extends Traversable<T>, T> C tabulate(int n, Function<? super Integer, ? extends T> f, C empty, Function<T[], C> of) {
354         Objects.requireNonNull(f, "f is null");
355         Objects.requireNonNull(empty, "empty is null");
356         Objects.requireNonNull(of, "of is null");
357         if (n <= 0) {
358             return empty;
359         } else {
360             @SuppressWarnings("unchecked")
361             final T[] elements = (T[]) new Object[n];
362             for (int i = 0; i < n; i++) {
363                 elements[i] = f.apply(i);
364             }
365             return of.apply(elements);
366         }
367     }
368 
369     static <T, U extends Seq<T>, V extends Seq<U>> V transpose(V matrix, Function<Iterable<U>, V> rowFactory, Function<T[], U> columnFactory) {
370         Objects.requireNonNull(matrix, "matrix is null");
371         if (matrix.isEmpty() || (matrix.length() == 1 && matrix.head().length() <= 1)) {
372             return matrix;
373         } else {
374             return transposeNonEmptyMatrix(matrix, rowFactory, columnFactory);
375         }
376     }
377 
378     private static <T, U extends Seq<T>, V extends Seq<U>> V transposeNonEmptyMatrix(V matrix, Function<Iterable<U>, V> rowFactory, Function<T[], U> columnFactory) {
379         final int newHeight = matrix.head().size(), newWidth = matrix.size();
380         @SuppressWarnings("unchecked") final T[][] results = (T[][]) new Object[newHeight][newWidth];
381 
382         if (matrix.exists(r -> r.size() != newHeight)) {
383             throw new IllegalArgumentException("the parameter `matrix` is invalid!");
384         }
385 
386         int rowIndex = 0;
387         for (U row : matrix) {
388             int columnIndex = 0;
389             for (T element : row) {
390                 results[columnIndex][rowIndex] = element;
391                 columnIndex++;
392             }
393             rowIndex++;
394         }
395 
396         return rowFactory.apply(Iterator.of(results).map(columnFactory));
397     }
398 
399     @SuppressWarnings("unchecked")
400     static <T> IterableWithSize<T> withSize(Iterable<? extends T> iterable) {
401         return isTraversableAgain(iterable) ? withSizeTraversable(iterable) : withSizeTraversable(List.ofAll(iterable));
402     }
403 
404     private static <T> IterableWithSize<T> withSizeTraversable(Iterable<? extends T> iterable) {
405         if (iterable instanceof Collection) {
406             return new IterableWithSize<>(iterable, ((Collection<?>) iterable).size());
407         } else {
408             return new IterableWithSize<>(iterable, ((Traversable<?>) iterable).size());
409         }
410     }
411     
412     static class IterableWithSize<T> {
413         private final Iterable<? extends T> iterable;
414         private final int size;
415 
416         IterableWithSize(Iterable<? extends T> iterable, int size) {
417             this.iterable = iterable;
418             this.size = size;
419         }
420 
421         java.util.Iterator<? extends T> iterator() {
422             return iterable.iterator();
423         }
424 
425         java.util.Iterator<? extends T> reverseIterator() {
426             return Collections.reverseIterator(iterable);
427         }
428 
429         int size() {
430             return size;
431         }
432 
433         @SuppressWarnings("unchecked")
434         Object[] toArray() {
435             if (iterable instanceof Collection<?>) {
436                 return ((Collection<? extends T>) iterable).toArray();
437             } else {
438                 return ArrayType.asArray(iterator(), size());
439             }
440         }
441     }
442 
443 }